#!/usr/bin/python
'''
Program name:
Structure Prototype Analysis Package (SPAP)

Description:
SPAP can analyze symmetry and compare similarity of a large number of atomic
structures. Typically, SPAP can process structures predicted by CALYPSO
(www.calypso.cn). We use spglib to analyze symmetry. Coordination
Characterization Function (CCF) is used to measure structural similarity. If
you use this program and method in your research, please read and cite the
following publication:
J. Phys. Condens. Matter 2017, 29, 165901.

Author:
Dr. Chuanxun Su

Email:
suchuanxun@163.cn / scx@calypso.cn

Dependency:
This program uses numpy, spglib, and ase (https://wiki.fysik.dtu.dk/ase/). You
can install them by one command: pip install numpy spglib ase

Usage:
SPAP can run in four different modes.
1.  To analyze CALYPSO structure prediction results, set i_mode=1. Run spap.py
    in the results directory generated by CALYPSO.
2.  To calculate symmetry and compare similarity of a set of structures, set
    i_mode=2. Put the structure files in the struc directory. The files should be
    named *.cif, *.vasp (VASP format), or use any name and format that ase can
    read automatically.
3.  To read and analyze structures optimized by VASP, set i_mode=3. Define
    work_dir. It is the path where you put the optimized structures.
4.  To calculate symmetry and compare similarity of a list of structures, set
    i_mode=4. Assign a list of Atoms objects to structure_list.
You could customize other parameters at the end of this script.

Output:
The output files are organized similarly to those generated by cak.py, so that
it is easier for users to get familiar with this program.
Analysis_Output.dat lists information about the analyzed structures. Compared
with cak.py, it provides additional information such as density, formula unit,
and volume.
distance.dat stores the distances between structures and some other attributes.
analyzed_structures.db is in Atomic Simulation Environment (ASE) database
format for atoms. You can easily read and analyze these structures through
ASE. You can also add and store properties you are interested in. This
functionality is very useful for screening out good functional materials. SPAP can
also write structures in CIF and VASP formats to the dir_* directory. Technically
speaking, SPAP can easily write any structure format supported by ASE.
'''

import os
import shutil
import argparse
import numpy as np
from ase import Atoms
from ase.db import connect
from ase.visualize import view
from ase.io import write, read
from ase.data import atomic_numbers
from spglib import standardize_cell, get_spacegroup
from .ccf import cal_inter_atomic_d, d2ccf, struc2ccf, cal_ccf_d, element_tag

# from spap.ccf import cal_inter_atomic_d, d2ccf, struc2ccf, cal_ccf_d, element_tag

try:
    import pickle
    import matplotlib.pyplot as plt
except:
    pass


def run_spap(symprec=0.1, e_range=0.4, total_struc=None, l_comp=True, threshold=None, r_cut_off=None, extend_r=1.0,
             ilat=2, ccf_step=0.02, l_db=True, l_cif=False, l_poscar=False, lprm=False, l_view=False, work_dir='./',
             structure_list=None, i_mode=1, lplt=False, ftype='CCF', apw=60.0, readf='XDATCAR', index=':', nsm=False,
             nfu=False):
    '''
    Start all the calculations: read structures, analyze their symmetry,
    optionally compare their similarity, and write the results.

    The process working directory is changed to ``work_dir`` (os.chdir) and
    every output file is written relative to it.

    :param symprec: float
        This precision is used to analyze symmetry of atomic structures.
    :param e_range: float
        Energy window above the lowest energy. Only structures inside it are
        analyzed (i_mode 1, 3 and 8, when total_struc is None or < 1).
    :param total_struc: int
        This number of lowest-energy structures will be analyzed (i_mode 1, 3
        and 8). None or a value < 1 selects structures by e_range instead.
    :param l_comp: bool
        Whether to compare similarity of structures.
    :param threshold: float
        Threshold for similar/dissimilar boundary. Defaults to 0.06 (0.035 for
        CALYPSO cluster predictions).
    :param r_cut_off: float
        Inter-atomic distances within this cut off radius will contribute to
        CCF. Defaults to 9.0 (6.0 for 2D systems).
    :param extend_r: float
        CCF will be calculated in the range of (0,r_cut_off+extend_r).
    :param ilat: int
        This parameter controls which method will be used to deal with lattice
        for comparing structural similarity.
        0 don't change lattice;
        1 equal particle number density;
        2 try equal particle number density and equal lattice.
    :param ccf_step: float
        Length of step for calculating CCF.
    :param l_db: bool
        Whether to write structures into the ase
        (https://wiki.fysik.dtu.dk/ase/) database file analyzed_structures.db.
    :param l_cif: bool
        Whether to write structures into cif files.
    :param l_poscar: bool
        Whether to write structures into files in VASP format.
    :param lprm: bool
        Also write primitive cells (*_p.cif / PCell_*.vasp) when l_cif or
        l_poscar is set.
    :param l_view: bool
        Whether to display the structures.
    :param work_dir: str
        Set working directory.
    :param structure_list: list of Atoms objects or None
        Assign a list of Atoms objects to structure_list when using i_mode=4 or
        i_mode=7. In i_mode 1 and 6 the structures read from struct.dat are
        appended to this list. None (the default) means a new empty list.
    :param i_mode: int
        Different functionality of SPAP.
        1 analyze CALYPSO prediction results;
        2 calculate symmetry and similarity of structures in struc directory;
        3 read and analyze structures optimized by VASP;
        4 calculate symmetry and similarity of structures in structure_list;
        5 read structures from ``readf`` and write their averaged CCF to
          ccf.pickle, rvf.pickle and ccf.csv;
        6 read CALYPSO prediction results and return them without analysis;
        7 calculate the CCF of every structure in structure_list;
        8 read and analyze structures optimized by ABACUS.
    :param lplt: bool
        Plot the CCF of every structure after the symmetry analysis and return
        None before comparing structures or writing output files.
    :param ftype: str
        CCF type passed to struc2ccf in i_mode 5 and 7.
    :param apw: float
        Passed to struc2ccf in i_mode 5 and 7.
    :param readf: str
        File read with ase.io.read in i_mode 5.
    :param index: str
        Index passed to ase.io.read in i_mode 5.
    :param nsm: bool
        Compare structures even when their space groups differ.
    :param nfu: bool
        Compare structures even when their numbers of atoms differ.
    :return:
        i_mode 5: dict of averaged CCFs (element-pair keys plus 'Total');
        i_mode 6: list of [Atoms, energy] pairs;
        i_mode 7: list with the CCF of every structure;
        otherwise None.
    '''
    print('Welcome using Structure Prototype Analysis Package (SPAP). The Coordination\n'
          'Characterization Function (CCF) is used to assess structural similarity. If\n'
          'you use this program and method in your research, please read and cite the\n'
          'following publication:\n'
          'J. Phys. Condens. Matter 2017, 29, 165901.\n')
    if structure_list is None:
        # A new list per call. A mutable default would be shared between calls
        # and accumulate the structures appended in i_mode 1 and 6.
        structure_list = []
    min_e = 0.0
    max_e = 0.0
    e_list = []
    pbc = [True, True, True]
    dir_name = 'dir_' + str(symprec)
    debug = False
    if i_mode == 5 or i_mode == 7:
        l_comp = False
        if r_cut_off is None:
            r_cut_off = 9.0
        ccf_range = r_cut_off + extend_r
        r_vector = np.linspace(0.0, ccf_range, int(ccf_range / ccf_step) + 2)
    os.chdir(work_dir)
    chemical_symbols = ''
    prediction_method = 'Unknown'
    if i_mode == 1:
        prediction_method = 'CALYPSO'
    # Metadata that not every mode collects. It is written to the database,
    # so it must always be defined.
    calculator = 'Unknown'
    nolo = None
    cal_p = 'Not collected\n'
    pseudopotential = 'Not collected\n'
    pressure = 0.0  # in GPa

    if i_mode == 1 or i_mode == 6:
        with open('../input.dat', 'r') as fin:
            prediction_parameters = fin.readlines()
        for line in prediction_parameters:
            if line.lstrip(' ').startswith('#'):
                continue
            elif 'NameOfAtoms' in line:
                chemical_symbols = [symbol for symbol in (line[line.find('=') + 1:-1]).split(' ') if symbol != '']
            elif 'ICode' in line:
                icode = int(line[line.find('=') + 1:-1])
                if icode == 1:
                    calculator = 'VASP'
                elif icode == 7:
                    calculator = 'Gaussian'
                else:
                    calculator = 'Unknown'
            elif 'NumberOfLocalOptim' in line:
                nolo = (line[line.find('=') + 1:-1]).replace(' ', '')
            elif 'Cluster' in line:
                if (line[line.find('=') + 1:-1].lstrip())[0] == 'T':
                    pbc = [False, False, False]
                    dir_name = 'dir_origin'
                    if l_comp:
                        ilat = 0
                        if threshold is None:
                            threshold = 0.035
            elif '2D' in line:
                if (line[line.find('=') + 1:-1].lstrip())[0] == 'T':
                    pbc = [True, True, False]
                    if l_comp:
                        ilat = 0
                        if threshold is None:
                            threshold = 0.06
            elif 'VSC' in line:
                if (line[line.find('=') + 1:-1].lstrip())[0] == 'T':
                    print('Not supported yet.')
                    exit()
            elif 'LSurface' in line:
                if (line[line.find('=') + 1:-1].lstrip())[0] == 'T':
                    print('Not supported yet.')
                    exit()
        if calculator == 'VASP':
            # The INCAR of the last local optimization step is ../INCAR_<NumberOfLocalOptim>.
            if nolo is not None and os.path.exists('../INCAR_' + nolo):
                with open('../INCAR_' + nolo, 'r') as incar:
                    cal_p = incar.readlines()
                for line in cal_p:
                    if not line.lstrip(' ').startswith('#') and 'PSTRESS' in line:
                        # PSTRESS (kbar in VASP) converted to GPa.
                        pressure = float(line[line.find('=') + 1:-1]) / 10
            if os.path.exists('../POTCAR'):
                with open('../POTCAR', 'r') as potcar:
                    # Collect the first line and the line following every ' End'
                    # line, i.e. the title line of each concatenated potential.
                    switch = False
                    first_line = True
                    for line in potcar:
                        if first_line:
                            pseudopotential = (line.lstrip()).rstrip()
                            first_line = False
                        elif line[:4] == ' End':
                            switch = True
                        elif switch:
                            pseudopotential += ' || ' + (line.lstrip()).rstrip()
                            switch = False
        with open('struct.dat', 'r') as struct:
            struct_lines = struct.readlines()
        n_structure = 0

        print('Reading energy')
        for i, line in enumerate(struct_lines):
            if 'Energy=' in line:
                n_structure = n_structure + 1
                # [line index of the 'Energy=' line, energy, structure number]
                e_list.append([i, float(line[8:]), n_structure])
        # Energies within 0.1 of 610612509.0 are dropped; presumably a
        # placeholder value for failed structures.
        if total_struc is None or total_struc < 1:
            min_e = min([x[1] for x in e_list])
            max_e = min_e + e_range
            e_list = [e for e in e_list if (e[1] < max_e) and (610612509.0 - e[1] > 0.1)]
        else:
            e_list = [e for e in e_list if 610612509.0 - e[1] > 0.1]
        e_list.sort(key=lambda x: x[1], reverse=False)
        if total_struc is not None and total_struc > 0:
            e_list = e_list[:total_struc]
        total_struc = len(e_list)

        print('Reading structure')
        ill = []
        for ii in range(total_struc):
            element_numbers = (struct_lines[e_list[ii][0] + 3][9:-1]).split(' ')
            element_numbers = [n for n in element_numbers if n != '']
            chemical_formula = ''
            for j, symbol in enumerate(chemical_symbols):
                chemical_formula = chemical_formula + symbol + element_numbers[j]
            try:
                # Fixed-width fields: cell vectors start 6 lines and fractional
                # coordinates 13 lines after the 'Energy=' line.
                structure_list.append(Atoms(
                    chemical_formula,
                    cell=[[float(struct_lines[e_list[ii][0] + 6 + k][j - 16:j]) for j in [17, 33, 49]] for k in
                          [0, 1, 2]],
                    scaled_positions=[[float(struct_lines[e_list[ii][0] + 13 + k][j - 12:j]) for j in [13, 25, 37]] for
                                      k in range(sum([int(n) for n in element_numbers]))], pbc=pbc))
            except Exception:
                ill.append(ii)
                print('Warning: structure in line {} in struct.dat was discarded.'.format(e_list[ii][0]))
        e_list = [e_list[x] for x in range(total_struc) if x not in ill]
        total_struc = len(e_list)
        if i_mode == 6:
            return [[structure_list[i], e_list[i][1]] for i in range(total_struc)]
    elif i_mode == 2:
        print('Reading structure')
        total_struc = 0
        for root, dirs, files in os.walk('struc', topdown=True):
            for name in files:
                path = os.path.join(root, name)
                try:
                    structure = read(path)
                except Exception:
                    print('Cann\'t read this file: ' + path)
                    continue
                # root already starts with 'struc', so join the full relative path.
                structure.fnm = os.path.join(work_dir, path)
                structure_list.append(structure)
                total_struc += 1

    elif i_mode == 3:
        calculator = 'VASP'
        print('Reading structure')
        i = 0
        for root, dirs, files in os.walk('.', topdown=True):
            for name in files:
                # INCAR/POTCAR metadata are taken only until the first OUTCAR was read.
                if i == 0 and ('INCAR' in name):
                    with open(os.path.join(root, name), 'r') as incar:
                        cal_p = incar.readlines()
                    for line in cal_p:
                        if not line.lstrip(' ').startswith('#') and 'PSTRESS' in line:
                            pressure = float(line[line.find('=') + 1:-1]) / 10
                    if os.path.exists(os.path.join(root, 'POTCAR')):
                        with open(os.path.join(root, 'POTCAR'), 'r') as potcar:
                            switch = False
                            first_line = True
                            for line in potcar:
                                if first_line:
                                    pseudopotential = (line.lstrip()).rstrip()
                                    first_line = False
                                elif line[:4] == ' End':
                                    switch = True
                                elif switch:
                                    pseudopotential += ' || ' + (line.lstrip()).rstrip()
                                    switch = False
                if 'OUTCAR' in name:
                    path = os.path.join(root, name)
                    try:
                        atoms = read(path, format='vasp-out')
                        e_per_atom = atoms.calc.results['energy'] / len(atoms.numbers)
                    except Exception:
                        # Nothing has been appended yet, so structure_list and
                        # e_list stay aligned.
                        print('Cann\'t read this file: ' + path)
                        continue
                    atoms.fnm = work_dir + path[1:]
                    # Be careful!!! The stored energy is replaced by the energy per atom.
                    atoms.calc.results['energy'] = e_per_atom
                    structure_list.append(atoms)
                    i += 1
                    if debug:
                        e_list.append([0, e_per_atom, i, path])
                    else:
                        e_list.append([0, e_per_atom, i])
        if total_struc is None or total_struc < 1:
            min_e = min([x[1] for x in e_list])
            max_e = min_e + e_range
            e_list = [e for e in e_list if e[1] < max_e]
            structure_list = [s for s in structure_list if s.calc.results['energy'] < max_e]
        e_list.sort(key=lambda x: x[1], reverse=False)
        structure_list.sort(key=lambda x: x.calc.results['energy'], reverse=False)
        if total_struc is not None and total_struc > 0:
            e_list = e_list[:total_struc]
            structure_list = structure_list[:total_struc]
        total_struc = len(e_list)
    elif i_mode == 4:
        total_struc = len(structure_list)
    elif i_mode == 5 or i_mode == 7:
        print('Reading structures')
        if i_mode == 5:
            structure_list = read(readf, index=index)
        averccf = {}
        total_struc = len(structure_list)
        if total_struc == 0:
            # Nothing to average, plot or return.
            print('Warning: no structure to analyze.')
            return [] if i_mode == 7 else {}
        elet = element_tag(structure_list[0].numbers, irt=2)
        for i, struc in enumerate(structure_list):
            struc.ccf = struc2ccf(struc, r_cut_off, r_vector, apw, ftype)
            if i_mode == 5:
                for key in struc.ccf.keys():
                    if i == 0:
                        averccf[key] = struc.ccf[key]
                    else:
                        averccf[key] = averccf[key] + struc.ccf[key]
        if i_mode == 7:
            return [struc.ccf for struc in structure_list]
        for i, key in enumerate(averccf.keys()):
            averccf[key] = averccf[key] / total_struc
            if i == 0:
                sumccf = averccf[key]
            else:
                sumccf = sumccf + averccf[key]
        # Translate the numeric '<rank>_<rank>' keys into element labels via element_tag.
        convccf = {}
        for key in averccf.keys():
            for i, rankn in enumerate([int(x) for x in key.split('_')]):
                for itm in elet.items():
                    if itm[1][0] == rankn:
                        if i == 0:
                            newkey = itm[1][1]
                            break
                        else:
                            newkey += '-' + itm[1][1]
                            convccf[newkey] = averccf[key]
                            break
        convccf['Total'] = sumccf
        with open('ccf.pickle', 'wb') as ccff:
            pickle.dump(convccf, ccff)
        with open('rvf.pickle', 'wb') as rvf:
            pickle.dump(r_vector, rvf)
        keyl = ['r'] + list(convccf.keys())
        with open('ccf.csv', 'w') as ccfd:
            ccfd.write('r')
            for x in keyl[1:]:
                ccfd.write(',{}'.format(x))
            for i in range(len(r_vector)):
                ccfd.write('\n')
                for x in keyl:
                    if x == 'r':
                        ccfd.write('{}'.format(r_vector[i]))
                    else:
                        ccfd.write(',{}'.format(convccf[x][i]))
        plt_ccf(convccf, r_vector, ftype)
        return convccf

    # Added by shenzx 20200530: structures optimized by ABACUS.
    elif i_mode == 8:
        from .GetEnergyStru import ReadAbacus
        calculator = 'ABACUS'
        structure_list = ReadAbacus()
        e_list = [[0, structure_list[i].Energy, i + 1] for i in range(len(structure_list))]
        # Apply the caller's total_struc / e_range selection, as in i_mode 3.
        if total_struc is None or total_struc < 1:
            min_e = min([x[1] for x in e_list])
            max_e = min_e + e_range
            e_list = [e for e in e_list if e[1] < max_e]
            structure_list = [s for s in structure_list if s.Energy < max_e]
        e_list.sort(key=lambda x: x[1], reverse=False)
        structure_list.sort(key=lambda x: x.Energy, reverse=False)
        if total_struc is not None and total_struc > 0:
            e_list = e_list[:total_struc]
            structure_list = structure_list[:total_struc]
        total_struc = len(e_list)

    periodic = pbc == [True, True, True] or pbc == [True, True, False]
    if periodic:
        print('Analyzing symmetry')
        if r_cut_off is None:
            if pbc == [True, True, True]:
                r_cut_off = 9.0
            else:
                r_cut_off = 6.0
    elif r_cut_off is None:
        r_cut_off = 9.0
    ccf_range = r_cut_off + extend_r
    r_vector = np.linspace(0.0, ccf_range, int(ccf_range / ccf_step) + 2)
    space_g_l = []
    for structure in structure_list:
        if periodic:
            structure.conventional_cell = standardize_cell(
                (structure.cell, structure.get_scaled_positions(wrap=True), structure.numbers), symprec=symprec)
            if structure.conventional_cell is None:
                # spglib failed: keep the input cell and mark the space group unknown.
                structure.conventional_cell = structure
                structure.space_group = 'NULL(0)'
                space_g_l.append(1)
            else:
                structure.conventional_cell = Atoms(cell=structure.conventional_cell[0],
                                                    scaled_positions=structure.conventional_cell[1],
                                                    numbers=structure.conventional_cell[2], pbc=pbc)
                structure.space_group = get_spacegroup(
                    (structure.cell, structure.get_scaled_positions(wrap=True), structure.numbers),
                    symprec=symprec).replace(' ', '')
                space_g_l.append(int(get_spg_n(structure.space_group)))
        elif pbc == [False, False, False]:
            structure.conventional_cell = structure
            structure.space_group = 'P1(1)'
            space_g_l.append(1)
        structure.conventional_cell.n_atom = len(structure.conventional_cell.numbers)
    if l_cif or l_poscar:
        # Empty (or create) the output directory and remove other dir_* directories.
        if os.path.exists(dir_name):
            for root, dirs, files in os.walk(dir_name, topdown=False):
                for name in files:
                    os.remove(os.path.join(root, name))
                for name in dirs:
                    os.rmdir(os.path.join(root, name))
        else:
            os.mkdir(dir_name)
        for root, dirs, files in os.walk('./', topdown=True):
            for name in dirs:
                if name[:4] == 'dir_' and name != dir_name:
                    shutil.rmtree(name)
            break
    if lplt:
        for structure in structure_list:
            ccf = struc2ccf(structure, r_cut_off, r_vector)
            plt_ccf(ccf, r_vector)
        return None

    if l_comp:
        print('Comparing similarity')
        if threshold is None:
            threshold = 0.06
        # struc_d[i] = [prototype index or -2, distance]; -2 marks a prototype.
        struc_d = classify_structures([x.conventional_cell for x in structure_list],
                                      space_g_l, threshold, r_cut_off, ilat, r_vector, nsm, nfu)
        d_f = open('distance.dat', 'w')
        if i_mode == 1:
            d_f.write('{:>11}{:>14}{:>15} {:>13} {:>12} {:>13}\n'
                      .format('No.', 'Enthalpy', symprec, 'Prototype ID', 'Distance', 'Formula unit'))
        elif i_mode == 2 or i_mode == 4:
            d_f.write('{:>5}{:>15} {:>13} {:>12} {:>13}\n'
                      .format('No.', symprec, 'Prototype ID', 'Distance', 'Formula unit'))
        elif i_mode == 3 or i_mode == 8:
            d_f.write('{:>11}{:>14}{:>15} {:>13} {:>12} {:>13}\n'
                      .format('No.', 'Energy', symprec, 'Prototype ID', 'Distance', 'Formula unit'))
    else:
        struc_d = [[-2, 0.0] for i in range(total_struc)]

    print('Writing out put files')
    if i_mode == 2 or i_mode == 3 or i_mode == 8:
        with open('structure_source.dat', 'w') as fstcn:
            for i, struct in enumerate(structure_list):
                fstcn.write('{:<6} {}\n'.format(i + 1, struct.fnm))
    with open('Analysis_Output.dat', 'w') as anal:
        if i_mode == 1:
            anal.write('{:>11}{:>14}{:>15} {:>10} {:>12} {:>10}\n'.format('No.', 'Enthalpy', symprec, 'Density',
                                                                          'Formula unit', 'Volume'))
        elif i_mode == 2 or i_mode == 4:
            anal.write('{:>5}{:>15} {:>10} {:>12} {:>10}\n'.format('No.', symprec, 'Density', 'Formula unit', 'Volume'))
        elif i_mode == 3 or i_mode == 8:
            anal.write('{:>11}{:>14}{:>15} {:>10} {:>12} {:>10}\n'.format('No.', 'Energy', symprec, 'Density',
                                                                          'Formula unit', 'Volume'))
        if i_mode == 1 or i_mode == 3 or i_mode == 8:
            format_a = '{:>4} ({:>4}){:>14.5f}{:>15} {:>10.5f} {:>12} {:>10.3f}\n'
        elif i_mode == 2 or i_mode == 4:
            format_a = '{:>5}{:>15} {:>10.5f} {:>12} {:>10.3f}\n'
        for i, id_d in enumerate(struc_d):
            if id_d[0] == -2:
                temp = nele_ctype_fu(count_atoms(structure_list[i].conventional_cell.numbers))
                if i_mode == 1 or i_mode == 3 or i_mode == 8:
                    anal.write(format_a
                               .format(i + 1, e_list[i][2], e_list[i][1], structure_list[i].space_group,
                                       structure_list[i].conventional_cell.n_atom /
                                       structure_list[i].conventional_cell.get_volume(), temp[2],
                                       structure_list[i].conventional_cell.get_volume()))
                elif i_mode == 2 or i_mode == 4:
                    anal.write(format_a
                               .format(i + 1, structure_list[i].space_group,
                                       structure_list[i].conventional_cell.n_atom /
                                       structure_list[i].conventional_cell.get_volume(), temp[2],
                                       structure_list[i].conventional_cell.get_volume()))
                if l_comp:
                    # The prototype itself, followed by every structure assigned to it.
                    if i_mode == 1 or i_mode == 3 or i_mode == 8:
                        d_f.write('{:>4} ({:>4}){:>14.5f}{:>15} {:>13} {:>12.4e} {:>13}\n'
                                  .format(i + 1, e_list[i][2], e_list[i][1], structure_list[i].space_group, i + 1,
                                          id_d[1], temp[2]))
                    elif i_mode == 2 or i_mode == 4:
                        d_f.write('{:>5}{:>15} {:>13} {:>12.4e} {:>13}\n'
                                  .format(i + 1, structure_list[i].space_group, i + 1, id_d[1], temp[2]))
                    for j, id_d2 in enumerate(struc_d[i + 1:]):
                        if id_d2[0] == i:
                            k = i + j + 1
                            temp = nele_ctype_fu(count_atoms(structure_list[k].conventional_cell.numbers))
                            if i_mode == 1 or i_mode == 3 or i_mode == 8:
                                d_f.write('{:>4} ({:>4}){:>14.5f}{:>15} {:>13} {:>12.4e} {:>13}\n'
                                          .format(k + 1, e_list[k][2], e_list[k][1], structure_list[k].space_group,
                                                  i + 1, id_d2[1], temp[2]))
                            elif i_mode == 2 or i_mode == 4:
                                d_f.write('{:>5}{:>15} {:>13} {:>12.4e} {:>13}\n'
                                          .format(k + 1, structure_list[k].space_group, i + 1, id_d2[1], temp[2]))
                    d_f.write('\n')
    if l_comp:
        d_f.close()
    left_id = [i for i in range(total_struc) if struc_d[i][0] == -2]
    if l_db and total_struc > 0:
        with connect('analyzed_structures.db', append=False) as db:
            # Run metadata is stored only with the first (prototype) row.
            data = {}
            data['spap_parameters'] = {
                'e_range': e_range, 'l_comp': l_comp, 'threshold': threshold, 'symprec': symprec, 'ilat': ilat,
                'r_cut_off': r_cut_off, 'extend_r': extend_r, 'ccf_step': ccf_step, 'total_struc': total_struc,
                'l_view': l_view, 'l_cif': l_cif, 'l_poscar': l_poscar, 'l_db': l_db}
            data['pseudopotential'] = pseudopotential
            if i_mode == 1:
                data['incar'] = ''.join(cal_p)
                data['prediction_parameters'] = ''.join(prediction_parameters)
                db.write(structure_list[left_id[0]].conventional_cell, relaxed=True, enthalpy=e_list[left_id[0]][1],
                         space_group=structure_list[left_id[0]].space_group, pressure=pressure,
                         prediction_method=prediction_method, experimental=False, opt_code=calculator,
                         data=data)
            elif i_mode == 2 or i_mode == 4:
                db.write(structure_list[left_id[0]].conventional_cell,
                         space_group=structure_list[left_id[0]].space_group, data=data)
            elif i_mode == 3:
                data['incar'] = ''.join(cal_p)
                db.write(structure_list[left_id[0]].conventional_cell, relaxed=True, e_per_a=e_list[left_id[0]][1],
                         space_group=structure_list[left_id[0]].space_group, pressure=pressure,
                         prediction_method=prediction_method, experimental=False, opt_code=calculator,
                         data=data)
            elif i_mode == 8:
                # ABACUS input files are not collected; cal_p holds the placeholder text.
                data['input'] = ''.join(cal_p)
                db.write(structure_list[left_id[0]].conventional_cell, relaxed=True, e_per_a=e_list[left_id[0]][1],
                         space_group=structure_list[left_id[0]].space_group, pressure=pressure,
                         prediction_method=prediction_method, experimental=False, opt_code=calculator,
                         data=data)

            for i in left_id[1:]:
                if i_mode == 1:
                    db.write(structure_list[i].conventional_cell, relaxed=True, enthalpy=e_list[i][1],
                             space_group=structure_list[i].space_group, pressure=pressure,
                             experimental=False, opt_code=calculator)
                elif i_mode == 2 or i_mode == 4:
                    db.write(structure_list[i].conventional_cell,
                             space_group=structure_list[i].space_group)
                elif i_mode == 3 or i_mode == 8:
                    db.write(structure_list[i].conventional_cell, relaxed=True, e_per_a=e_list[i][1],
                             space_group=structure_list[i].space_group, pressure=pressure,
                             experimental=False, opt_code=calculator)
    if l_poscar and left_id:
        # Element order used for every POSCAR, taken from the first prototype.
        ctat = count_atoms(structure_list[left_id[0]].numbers, 2)
    for i in left_id:
        if lprm:
            prmc = standardize_cell(
                (structure_list[i].cell, structure_list[i].get_scaled_positions(wrap=True), structure_list[i].numbers),
                symprec=symprec, to_primitive=True)
            if prmc is None:
                prmc = structure_list[i]
            else:
                prmc = Atoms(cell=prmc[0], scaled_positions=prmc[1], numbers=prmc[2], pbc=pbc)
        if l_cif:
            write(dir_name + '/' + str(i + 1) + '_' + get_spg_n(structure_list[i].space_group) + '.cif',
                  structure_list[i].conventional_cell)
            if lprm:
                write(dir_name + '/' + str(i + 1) + '_' + get_spg_n(structure_list[i].space_group) + '_p.cif',
                      prmc)
        if l_poscar:
            write_struc(structure_list[i].conventional_cell, ctat,
                        dir_name + '/UCell_' + str(i + 1) + '_' + get_spg_n(structure_list[i].space_group) + '.vasp',
                        structure_list[i].space_group)
            if lprm:
                write_struc(prmc, ctat, dir_name + '/PCell_' + str(i + 1) + '_' + get_spg_n(
                    structure_list[i].space_group) + '.vasp', structure_list[i].space_group)
    if l_view:
        view([structure_list[i].conventional_cell for i in left_id])
    n_left = len(left_id)
    if n_left != 0:
        print('Multiplicity: {:6.3f}'.format(total_struc / n_left))
    print('Calculation succeeded')


def write_struc(struc, ctat, strucn, tag='generated by BDM'):
    """Write ``struc`` to ``strucn`` as a VASP POSCAR file in direct coordinates.

    Layout: ``tag`` line, scale factor ``1.0``, three lattice vectors, element
    symbols, per-element counts, ``Direct``, then one fractional-coordinate
    line per atom (no trailing newline).

    Args:
        struc: Atoms-like object exposing ``cell``, ``numbers`` and
            ``get_scaled_positions()``.
        ctat: mapping keyed by atomic number (e.g. from ``count_atoms(..., 2)``);
            its key order fixes the species order of the symbol/count lines and
            of the coordinate block. Only its keys are used.
        strucn: output file path.
        tag: text written as the first (comment) line.
    """
    # Gather everything that can fail before opening the file, so an error does
    # not leave a truncated POSCAR behind.
    smbd = getcf(struc.numbers, ctat, 2)
    scaled_pos = struc.get_scaled_positions(wrap=True)
    # ``with`` guarantees the handle is closed even if a write fails.
    with open(strucn, 'w') as poscar:
        poscar.write(tag + '\n1.0\n')
        for v in struc.cell:
            poscar.write('{:>13.7f}{:>13.7f}{:>13.7f}\n'.format(v[0], v[1], v[2]))
        ele_n = ''
        for smb in smbd.keys():
            poscar.write('{:4}'.format(smb))
            ele_n += ' {:>3}'.format(smbd[smb])
        poscar.write('\n' + ele_n + '\nDirect')
        # Atoms are grouped by species, in the same order as the header lines.
        for n in ctat.keys():
            for j, pos in enumerate(scaled_pos):
                if struc.numbers[j] == n:
                    poscar.write('\n{:>10.7f} {:>10.7f} {:>10.7f}'.format(pos[0], pos[1], pos[2]))


def getcf(numbers, ctat, irt=1):
    """Describe the composition of ``numbers`` species by species.

    Species are visited in the key order of ``ctat`` (a mapping keyed by
    atomic number, e.g. from ``count_atoms(..., 2)``); only its keys are used.
    Counts are recomputed from ``numbers``, which is expected to be a NumPy
    array so that ``numbers == key`` compares elementwise.

    Args:
        numbers: atomic numbers of the structure.
        ctat: mapping whose keys are the atomic numbers to report, in order.
        irt: 1 -> return a formula string such as ``'SiO2'`` (a count of 1 is
            omitted); 2 -> return a dict ``{symbol: count}``. Any other value
            returns None.

    Raises:
        ValueError: if a key of ``ctat`` has no symbol in ``atomic_numbers``.
    """
    # Invert the symbol -> atomic-number table once. setdefault keeps the first
    # symbol listed for a number, i.e. the same answer a first-match scan gives.
    symbol_of = {}
    for symbol, z in atomic_numbers.items():
        symbol_of.setdefault(z, symbol)

    cf = ''
    smbd = {}
    for key in ctat.keys():
        eles = symbol_of.get(key)
        if eles is None:
            # Fail loudly instead of reusing the symbol of a previous species.
            raise ValueError('unknown atomic number: {}'.format(key))
        ict = np.sum(numbers == key)
        smbd[eles] = ict
        cf += eles
        if ict != 1:
            cf += str(ict)
    if irt == 1:
        return cf
    elif irt == 2:
        return smbd


def get_spg_n(spg):
    """Extract the text inside the parentheses of a space-group label.

    Returns everything after the first '(' up to, but not including, the last
    character, e.g. ``'Fm-3m(225)' -> '225'``. The last character (normally the
    closing ')') is dropped without being checked. Raises ValueError when
    ``spg`` contains no '('.
    """
    start = spg.index('(') + 1
    end = len(spg) - 1
    return spg[start:end]


def classify_structures(structures, space_groups, threshold, r_cut_off, ilat, r_vector, nsm=False, nfu=False):
    """Cluster structures into prototypes and return their classification.

    Structures are visited in order. Each structure ``i`` that is still
    unclassified leads a group made of itself plus every later unclassified
    structure that shares its space group (unless ``nsm``) and its ``n_atom``
    (unless ``nfu``). The group is then handed to ``cal_struc_d``, which fills
    in ``struc_d``.

    Args:
        structures: sequence of Atoms-like objects with an ``n_atom`` attribute.
        space_groups: one entry per structure, compared with ``==`` and passed
            to ``cal_struc_d`` as ``spg_n``.
        threshold, r_cut_off, r_vector: forwarded to ``cal_struc_d``.
        ilat: lattice-handling mode. When non-zero, the first group leader with
            a given ``n_atom`` fixes the reference volume per atom used for all
            later groups led by a structure with that ``n_atom``. When 0, 100.0
            is passed; ``cal_struc_d`` does not use it in that mode.
        nsm: if True, ignore space groups when grouping.
        nfu: if True, ignore atom counts when grouping.

    Returns:
        list with one ``[label, distance]`` pair per structure; entries start
        as ``[-1, 0.0]`` and are filled in by ``cal_struc_d``.
    """
    n_struc = len(structures)
    struc_d = [[-1, 0.0] for _ in range(n_struc)]
    ref_volume = {}  # n_atom -> cell volume of the first group leader with that n_atom

    for i in range(n_struc):
        if struc_d[i][0] != -1:
            continue  # already assigned while processing an earlier group
        lead = structures[i]
        if ilat != 0 and lead.n_atom not in ref_volume:
            ref_volume[lead.n_atom] = lead.get_volume()

        group = [i]
        for x in range(i + 1, n_struc):
            if struc_d[x][0] != -1:
                continue
            if not (space_groups[i] == space_groups[x] or nsm):
                continue
            if not (lead.n_atom == structures[x].n_atom or nfu):
                continue
            group.append(x)

        if ilat != 0:
            volume_per_atom = ref_volume[lead.n_atom] / lead.n_atom
        else:
            volume_per_atom = 100.0
        cal_struc_d(structures, group, struc_d, space_groups[i], threshold, r_cut_off,
                    volume_per_atom, ilat, r_vector)
    return struc_d


def cal_struc_d(structures, id_list, struc_d, spg_n, threshold, r_cut_off, volume, ilat, r_vector):
    """Assign each structure in ``id_list`` to a prototype, or make it a new one.

    ``struc_d`` is updated in place. ``id_list[0]`` always becomes a prototype
    (label -2). Every later index ``i`` ends up either as ``[j, d]``, meaning
    it matched prototype ``j`` with CCF distance ``d < threshold``, or as
    ``[-2, 0.0]``, meaning it is a new prototype. Entries are expected to start
    with label -1 (the new-prototype test below checks for -1).

    When ``id_list`` has more than one element, each structure in it gets a
    ``.ccf`` attribute computed by ``struc2ccf``. A single-element group is
    only labelled.

    Args:
        structures: sequence of Atoms-like objects exposing ``cell``,
            ``numbers``, ``n_atom``, ``pbc``, ``get_volume()`` and
            ``get_scaled_positions()``.
        id_list: indices into ``structures`` forming one group; ``id_list[0]``
            is the reference.
        struc_d: per-structure ``[label, distance]`` list, mutated in place.
        spg_n: space group of the group; with ``ilat == 2`` and
            ``15 < spg_n < 195`` a second, same-cell comparison is tried.
        threshold: CCF distance below which two structures count as similar.
        r_cut_off, r_vector: forwarded to ``struc2ccf``.
        volume: target volume per atom; when ``ilat != 0`` cells are scaled
            isotropically to it before computing CCFs.
        ilat: 0 keeps lattices unchanged; any other value rescales them to
            ``volume`` per atom; 2 additionally enables the same-cell retry.
    """
    # The reference structure is always a prototype.
    struc_d[id_list[0]][0] = -2
    prototype_id = [id_list[0]]
    # NOTE(review): presumably spg_n is the International Tables number, so
    # 16-194 would be the orthorhombic..hexagonal range; confirm. Only these
    # groups (and only with ilat == 2) get the same-cell retry below.
    if spg_n > 15 and spg_n < 195 and ilat == 2:
        l_same_cell = True
    else:
        l_same_cell = False
    # A single-member group needs no comparison and no CCF.
    if len(id_list) != 1:
        if ilat == 0 or volume == structures[id_list[0]].get_volume() / structures[id_list[0]].n_atom:
            # Use the reference cell as-is: lattices are kept (ilat == 0) or it
            # already has exactly the target volume per atom.
            structures[id_list[0]].ccf = struc2ccf(structures[id_list[0]], r_cut_off, r_vector)
        else:
            # Scale the cell by (volume * n_atom / V) ** (1/3), so the rebuilt
            # cell's volume is volume * n_atom, i.e. `volume` per atom.
            # NOTE(review): pbc is taken from structures[0], not id_list[0];
            # this assumes all structures share the same pbc.
            structures[id_list[0]].ccf = \
                struc2ccf(Atoms(
                    cell=structures[id_list[0]].cell * (
                            volume / structures[id_list[0]].get_volume() * structures[id_list[0]].n_atom) ** (
                                 1.0 / 3.0),
                    scaled_positions=structures[id_list[0]].get_scaled_positions(wrap=True),
                    numbers=structures[id_list[0]].numbers, pbc=structures[0].pbc), r_cut_off, r_vector)
        for i in id_list[1:]:
            # CCF of candidate i, rescaled to `volume` per atom when ilat != 0.
            if ilat == 0:
                structures[i].ccf = struc2ccf(structures[i], r_cut_off, r_vector)
            else:
                scaled_positions = structures[i].get_scaled_positions(wrap=True)
                structures[i].ccf = struc2ccf(
                    Atoms(cell=structures[i].cell * (volume / structures[i].get_volume() * structures[i].n_atom) ** (
                            1.0 / 3.0),
                          scaled_positions=scaled_positions, numbers=structures[i].numbers, pbc=structures[0].pbc),
                    r_cut_off, r_vector)
            # Compare with the prototypes found so far, newest first; the first
            # distance below threshold assigns i to that prototype.
            for j in [prototype_id[-1 - j2] for j2 in range(len(prototype_id))]:
                struc_d[i][1] = cal_ccf_d(structures[j].ccf, structures[i].ccf)
                if struc_d[i][1] < threshold:
                    struc_d[i][0] = j
                    break
                elif l_same_cell:
                    # Retry with i's fractional coordinates placed in prototype
                    # j's cell, scaled to `volume` per atom using j's n_atom.
                    # scaled_positions is bound here: l_same_cell requires
                    # ilat == 2, so the else branch above has run.
                    struc_d[i][1] = cal_ccf_d(
                        structures[j].ccf, struc2ccf(Atoms(
                            cell=structures[j].cell * (volume / structures[j].get_volume() * structures[j].n_atom) ** (
                                    1.0 / 3.0),
                            scaled_positions=scaled_positions, numbers=structures[i].numbers,
                            pbc=structures[0].pbc), r_cut_off, r_vector))
                    if struc_d[i][1] < threshold:
                        struc_d[i][0] = j
                        break
            # Mark this structure as a new prototype.
            if struc_d[i][0] == -1:
                struc_d[i] = [-2, 0.0]
                prototype_id.append(i)


def count_atoms(numbers, imd=1):
    """Count how often each value (atomic number) occurs in ``numbers``.

    Args:
        numbers: iterable of atomic numbers.
        imd: 1 -> return the counts as an ascending list;
             2 -> return a dict ``{number: count}`` in first-appearance order.
             Any other value returns None.
    """
    tally = {}
    for z in numbers:
        tally[z] = tally.get(z, 0) + 1
    if imd == 1:
        return sorted(tally.values())
    if imd == 2:
        return tally
    return None


def nele_ctype_fu(natom):
    """Summarise per-species atom counts as (species count, ratio label, GCD).

    ``natom`` is a sequence of atom counts, one per species. The counts are
    divided by their greatest common divisor and joined with ``strctype``,
    e.g. ``[2, 4] -> (2, '1_2', 2)``.

    Special cases:
        * empty input, or a first count of 0 -> ``(0, '0', 0)``;
        * a single species -> ``(1, '1', count)``;
        * coprime counts -> the counts unchanged with a GCD of 1.
    """
    if len(natom) == 0 or natom[0] == 0:
        return 0, '0', 0
    n_species = len(natom)
    divisor = natom[0]
    if n_species == 1:
        return 1, '1', divisor
    for count in natom[1:]:
        # Euclid's algorithm on (divisor, count); stop early once the
        # remainder is 1, because the overall GCD must then be 1.
        small, big = divisor, count
        rem = big % small
        while rem != 0 and rem != 1:
            small, big = rem, small
            rem = big % small
        if rem == 1:
            return n_species, strctype(natom), 1
        divisor = small
    reduced = [int(float(c) / divisor + 0.5) for c in natom]
    return n_species, strctype(reduced), divisor


def strctype(ctype):
    """Join the items of ``ctype`` with underscores.

    Each item is converted with ``str``, e.g. ``[1, 2, 3] -> '1_2_3'`` and
    ``[4] -> '4'``. An empty sequence yields ``''``.
    """
    return '_'.join(str(c) for c in ctype)


def plt_ccf(ccf, r_vector, ftype, ltt=True):
    """Plot the curves of a CCF dict stacked vertically, then show the figure.

    Args:
        ccf: dict of curves sampled on ``r_vector``; the ``'Total'`` entry is
            treated specially.
        r_vector: x values shared by every curve.
        ftype: figure title.
        ltt: if True, draw ``ccf['Total']`` at the bottom without offset.

    Each non-``'Total'`` curve is shifted up by the running offset. After each
    drawn curve the offset grows by ``max(0.3, 1.2 * peak)`` of that curve.
    """
    plt.title(ftype)
    offset = 0.0
    if ltt:
        total = ccf['Total']
        plt.plot(r_vector, total, 'g-', linewidth=2)
        offset = max(0.3, 1.2 * max(total))
    partial_keys = [k for k in ccf.keys() if k != 'Total']
    for name in partial_keys:
        curve = ccf[name]
        plt.plot(r_vector, curve + offset, 'g-', linewidth=2)
        offset += max(0.3, 1.2 * max(curve))
    plt.grid(True)
    plt.show()


def show2ccf(ccf1, ccf2, r_vector):
    """Plot ``ccf1[key] - ccf2[key]`` for every key of ``ccf1``, stacked, and show it.

    After each curve the vertical offset grows by ``1.1 * max`` of that
    difference curve.
    NOTE(review): if a difference curve is entirely <= 0, the offset does not
    grow (or shrinks), so neighbouring curves can overlap.
    """
    offset = 0.0
    for name in ccf1:
        delta = ccf1[name] - ccf2[name]
        plt.plot(r_vector, delta + offset, 'g-', linewidth=2)
        offset += 1.1 * max(delta)
    plt.grid(True)
    plt.show()


class _LineBreakHelpFormatter(argparse.HelpFormatter):
    """Help formatter that honours explicit newlines in argument help text.

    argparse's default HelpFormatter collapses every run of whitespace in a
    help string, newlines included, into a single space. The numbered option
    lists written for ``--ilat`` and ``--i_mode`` therefore printed as one
    run-on paragraph. This formatter splits on newlines first and wraps each
    line on its own. Description and usage formatting are unchanged.
    """

    def _split_lines(self, text, width):
        # Overrides the same (undocumented) hook RawTextHelpFormatter uses, but
        # still wraps long lines to the available width.
        import textwrap
        lines = []
        for raw_line in text.splitlines():
            # Collapse internal whitespace per line, as the default formatter does.
            line = ' '.join(raw_line.split())
            lines.extend(textwrap.wrap(line, width) or [''])
        return lines


def start_cli():
    """Parse command-line options and call ``run_spap`` with them.

    ``-a`` overrides ``-n`` by setting ``total_struc`` to 99999999. The
    negative flags ``--nc`` and ``--nd`` are inverted into ``l_comp`` and
    ``l_db``.
    """
    helpl = '''
this parameter controls which method will be used to deal with lattice for comparing structural similarity
0 don't change lattice
1 equal particle number density
2 try equal particle number density and equal lattice (default: %(default)s)
'''.strip()
    parser = argparse.ArgumentParser(
        description='SPAP can analyze symmetry and compare similarity of a large number of atomic structures. '
                    'Typically, SPAP can process structures predicted by CALYPSO (www.calypso.cn).',
        # Keep the line breaks of the multi-line help texts (--ilat, --i_mode).
        formatter_class=_LineBreakHelpFormatter,
    )
    parser.add_argument('-t', '--tolerance', '--symprec', type=float, default=0.1, dest='symprec',
                        help='this precision is used to analyze symmetry of atomic structures (default: %(default)s)')
    parser.add_argument('-e', '--e_range', type=float, default=0.4,
                        help='define an energy range in which structures will be analyzed (default: %(default)s)')
    parser.add_argument('-n', '--total_struc', type=int, default=None,
                        help='this number of structures will be analyzed  (default: %(default)s)')
    parser.add_argument('-a', action='store_true', help='process all the structures')
    parser.add_argument('--nc', '--n_comp', action='store_true',
                        help='not to compare similarity of structures  (default: %(default)s)')
    parser.add_argument('--th', '--threshold', type=float, default=None, dest='threshold',
                        help='threshold for similar/dissimilar boundary (default: %(default)s)')
    parser.add_argument('-r', '--r_cut_off', type=float, default=None,
                        help='inter-atomic distances within this cut off radius will contribute to CCF '
                             '(default: %(default)s Angstrom)')
    parser.add_argument('-l', '--ilat', type=int, choices=[0, 1, 2], default=2,
                        help=helpl)
    parser.add_argument('--nd', '--no_db', action='store_true',
                        help='not to write structures into ase (https://wiki.fysik.dtu.dk/ase/) database file '
                             '(default: %(default)s)')
    parser.add_argument('--cif', '--l_cif', action='store_true', dest='l_cif',
                        help='write structures into cif files (default: %(default)s)')
    parser.add_argument('--pos', '--vasp', '--l_poscar', action='store_true', dest='l_poscar',
                        help='write structures into files in VASP format (default: %(default)s)')
    parser.add_argument('-d', '--l_view', action='store_true', help='display the structures (default: %(default)s)')
    parser.add_argument('-w', '--work_dir', type=str, default='./', help='set working directory (default: %(default)s)')
    parser.add_argument('-i', '--i_mode', type=int, choices=[1, 2, 3], default=1,
                        help='different functionality of SPAP: \n1 analyze CALYPSO prediction results; \n2 calculate '
                             'symmetry and similarity of structures in struc directory; \n3 read and analyze '
                             'structures optimized by VASP (default: %(default)s)')
    parser.add_argument('-v', '--version', action='version', version='SPAP: 1.0.2')
    args = parser.parse_args()
    if args.a:
        args.total_struc = 99999999
    run_spap(
        symprec=args.symprec,
        e_range=args.e_range,
        total_struc=args.total_struc,
        threshold=args.threshold,
        r_cut_off=args.r_cut_off,
        ilat=args.ilat,
        l_comp=not args.nc,
        l_db=not args.nd,
        l_cif=args.l_cif,
        l_poscar=args.l_poscar,
        work_dir=args.work_dir,
        i_mode=args.i_mode,
        l_view=args.l_view,
    )


if __name__ == "__main__":
    # Script entry point: parse CLI options and run the analysis.
    start_cli()
